# -*- coding: utf-8 -*-
"""
Created on Mon Apr 20 10:28:50 2020

@author: xiang_yaobing
"""

import os 
import random
from PIL import Image
from tensorflow.keras.models import Sequential,load_model
from tensorflow.keras.layers import LSTM,Dense,Activation,SimpleRNN,Conv2D,MaxPool2D,Flatten,Reshape,Dropout
from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.optimizers import RMSprop
import numpy as np
from tensorflow.keras import utils
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import cv2

def _load_folder(folder):
    """Read and rescale every file in *folder*; return the images as a list.

    Entries are visited in sorted order so repeated runs see the same sample
    order (``os.listdir`` returns entries in arbitrary order).

    Raises:
        OSError: if an entry cannot be decoded by OpenCV (non-image file,
            sub-directory, corrupt file).
    """
    images = []
    for name in sorted(os.listdir(folder)):
        file_path = os.path.join(folder, name)
        # IMREAD_UNCHANGED keeps the file's own bit depth; the original
        # comment mentions 65535, so 16-bit images are presumably expected.
        im = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
        if im is None:
            # cv2.imread reports an unreadable file by returning None.
            raise OSError('cannot read image file: ' + file_path)
        # NOTE(review): this divides by max, not by (max - min), so it is only
        # a true min-max scaling when min == 0, and an all-zero image gives
        # NaN/inf. Kept unchanged because the models were presumably trained
        # with the same preprocessing -- confirm before changing it.
        images.append((im - im.min()) / im.max())
    return images


def load_data(path1, path2):
    """Load the two-class image test set and build its integer labels.

    Every file in ``path1`` becomes a class-0 sample (stars / noise) and
    every file in ``path2`` a class-1 sample (targets). Each image is read
    with ``cv2.IMREAD_UNCHANGED`` and rescaled as ``(im - im.min()) / im.max()``.

    Args:
        path1: directory holding the class-0 (star / noise) images.
        path2: directory holding the class-1 (target) images.

    Returns:
        (x, y) where ``x`` is ``np.array`` of the images (all ``path1``
        files first, then all ``path2`` files, each group in sorted file-name
        order) and ``y`` is an int array with 0 for ``path1`` samples and 1
        for ``path2`` samples, aligned with ``x``.

    Raises:
        OSError: if any directory entry cannot be decoded as an image.
    """
    images1 = _load_folder(path1)  # class 0: stars / noise
    images2 = _load_folder(path2)  # class 1: targets
    labels = np.concatenate([np.zeros(len(images1), dtype=int),
                             np.ones(len(images2), dtype=int)])
    return np.array(images1 + images2), labels

# ---------------------------------------------------------------------------
# Configuration: test-set directories and saved-model files.
# ---------------------------------------------------------------------------
nb_classes = 2
path1 = '..\\testset\\noClusterKDEpicture_multiKde\\'  # stars / noise (label 0)
path2 = '..\\testset\\ClusterKDEpicture_multiKde\\'    # targets (label 1)

# Original comment on s1Path: "single real data".
# NOTE(review): the file name says "simulate" and its metrics are printed as
# 'simulatedata' later in this script -- confirm which description is right.
s1Path = '..\\model\\simulate_best_weights_v1.0.h5'
s2Path = '..\\model\\s2_best_weights_v1.0.h5'

# The ten ensemble members (note: every file name carries the "E1_" prefix).
(E0Path, E1Path, E2Path, E3Path, E4Path,
 E5Path, E6Path, E7Path, E8Path, E9Path) = (
    '..\\model\\E1_best_weights_v0.h5',
    '..\\model\\E1_best_weights_v1.h5',
    '..\\model\\E1_best_weights_v2.h5',
    '..\\model\\E1_best_weights_v3.h5',
    '..\\model\\E1_best_weights_v4.h5',
    '..\\model\\E1_best_weights_v5.h5',
    '..\\model\\E1_best_weights_v6.h5',
    '..\\model\\E1_best_weights_v7.h5',
    '..\\model\\E1_best_weights_v8.h5',
    '..\\model\\E1_best_weights_v9.h5',
)

singlePath = '..\\model\\single_best_weights_v1.0.h5'

# ---------------------------------------------------------------------------
# Evaluation data: `data` feeds the S1/S2/ensemble models, `data1` feeds
# modelSingle (which receives images with an explicit trailing channel axis).
# ---------------------------------------------------------------------------
data, labels = load_data(path1, path2)
data1, labels1 = load_data('..\\testset\\noClusterKDEpicture\\no_cluster\\loc\\',
                           '..\\testset\\ClusterKDEpicture\\cluster\\loc\\')
data1 = data1.reshape((data1.shape[0], 60, 60, 1))

# ---------------------------------------------------------------------------
# Load every saved model (same order as the paths above).
# ---------------------------------------------------------------------------
modelS1, modelS2 = load_model(s1Path), load_model(s2Path)
(modelE0, modelE1, modelE2, modelE3, modelE4,
 modelE5, modelE6, modelE7, modelE8, modelE9) = [
    load_model(weights_file)
    for weights_file in (E0Path, E1Path, E2Path, E3Path, E4Path,
                         E5Path, E6Path, E7Path, E8Path, E9Path)
]
modelSingle = load_model(singlePath)
#%%
# Hard class predictions for every model, plus the soft-voting ensemble.
# `data` / `data1` were loaded in the setup cell above; no need to re-read
# the test set from disk here.


def _predict_classes(model, x):
    """Return hard class predictions, mirroring Keras ``predict_classes``.

    ``Sequential.predict_classes`` was removed in TensorFlow 2.6. This keeps
    its rule: argmax over the last axis when the model outputs several
    columns, otherwise threshold the single probability at 0.5.
    """
    proba = model.predict(x, verbose=0)
    if proba.shape[-1] > 1:
        return proba.argmax(axis=-1)
    return (proba > 0.5).astype('int32')


S1Result = _predict_classes(modelS1, data)
S2Result = _predict_classes(modelS2, data)

# Per-member class probabilities. Each member now uses its own model:
# previously E5Result was computed with modelE4, so E4 was counted twice in
# the ensemble and E5 not at all.
(E0Result, E1Result, E2Result, E3Result, E4Result,
 E5Result, E6Result, E7Result, E8Result, E9Result) = [
    member.predict(data)
    for member in (modelE0, modelE1, modelE2, modelE3, modelE4,
                   modelE5, modelE6, modelE7, modelE8, modelE9)
]
singleResult = _predict_classes(modelSingle, data1)

# Soft voting: sum the members' probabilities, then pick the top class.
EResult = np.argmax(sum([E0Result, E1Result, E2Result, E3Result, E4Result,
                         E5Result, E6Result, E7Result, E8Result, E9Result]),
                    axis=1)

#%%
# Classification metrics (computed by the project-local model_tools module)
# for the simulated-data model, the real+simulated model and the ensemble.
import model_tools
for caption, predicted in (('simulatedata', S1Result),
                           ('real+simulatedata', S2Result),
                           ('Ensemble', EResult)):
    print(caption)
    model_tools.model_metrics(labels, predicted)
#%%
# Probability scores for ROC analysis. Column 1 is used as the score of the
# positive class (label 1, targets). The ensemble score is the plain sum of
# the ten members' probabilities; AUC/ROC depend only on the ranking, so not
# averaging does not change them.
p_S1Result = modelS1.predict(data)
p_S2Result = modelS2.predict(data)
p_EResult = (E0Result + E1Result + E2Result + E3Result + E4Result
             + E5Result + E6Result + E7Result + E8Result + E9Result)

from sklearn.metrics import roc_auc_score, roc_curve  # ROC curve / AUC helpers

auc1 = roc_auc_score(labels, p_S1Result[:, 1])
auc2 = roc_auc_score(labels, p_S2Result[:, 1])
auc3 = roc_auc_score(labels, p_EResult[:, 1])
print('AUC1 = ', auc1)
print('AUC2 = ', auc2)
print('AUC3 = ', auc3)

# Display any figure still open from earlier cells so the ROC curves start on
# a fresh figure. NOTE(review): presumably model_tools.model_metrics may plot;
# confirm. This is harmless when nothing is open.
plt.show()
plt.rcParams['savefig.dpi'] = 100  # resolution of saved images
plt.rcParams['figure.dpi'] = 100   # on-screen figure resolution

# One ROC curve per model. The ensemble curve used to be labelled 'E1', which
# is the name of a single ensemble member; it is now labelled 'Ensemble' to
# match the metrics printout. After the loop fpr1/tpr1/thresholds1 hold the
# ensemble's curve, exactly as before.
for curve_label, scores, colour in (('S1', p_S1Result, 'r'),
                                    ('S2', p_S2Result, 'g'),
                                    ('Ensemble', p_EResult, 'b')):
    fpr1, tpr1, thresholds1 = roc_curve(labels, scores[:, 1], pos_label=1)
    plt.plot(fpr1, tpr1, linewidth=2, label=curve_label, c=colour)
plt.ylim(0, 1.05)
plt.xlim(0, 1.05)
plt.legend(loc=4)  # legend in the lower-right corner
plt.show()
#    plt.rcParams['savefig.dpi'] = 100 #图片像素
#    plt.rcParams['figure.dpi'] = 100 #分辨率
#pic = plt.figure(figsize=(10,7))
#plt.savefig('..\\picture\\ROC_'+de[f]+'.png')
#
# Metrics printed by a previous run of this script:
#simulatedata
#acc 0.8883720930232558
#recall_score 0.8434782608695652
#prec_score 0.941747572815534
#f1_score 0.8899082568807339
#real+simulatedata
#acc 0.9162790697674419
#recall_score 0.9652173913043478
#prec_score 0.888
#f1_score 0.9249999999999999
#Ensemble
#acc 0.9441860465116279
#recall_score 0.991304347826087
#prec_score 0.912
#f1_score 0.9500000000000001